# -*- coding: utf-8 -*-
# @Author: lidongdong
# @time  : 19-1-24 9:29 PM
# @file  : prepare_vocab.py

import collections
import io
import json

# Reserved integer ids for the special tokens. Ordinary vocabulary words are
# numbered after these.
SOS_ID, EOS_ID, UNK_ID, PAD_ID = range(4)

# Special-token strings (conventionally start-of-sequence, end-of-sequence,
# unknown-word and padding markers), stored in the vocabulary under the ids
# above.
SOS, EOS, UNK, PAD = "<sos>", "<eos>", "<unk>", "<pad>"


def create_vocab_file(json_filename, new_vocab_file):
    """Build a word -> id vocabulary from a caption JSON file and save it.

    The special tokens SOS, EOS, UNK and PAD always keep their fixed ids
    (SOS_ID, EOS_ID, UNK_ID, PAD_ID). Every other lower-cased word found in
    ``content[*]["caption_splits"]`` gets the next free id, most frequent
    word first. Words with equal counts are ordered alphabetically, so the
    mapping is reproducible between runs.

    Progress is printed to stdout: the number of distinct words, one
    "<rank> <word> <count>" line per word, then "Finished.".

    Args:
        json_filename: path to a UTF-8 encoded JSON file with a top-level
            "content" list. Each item has a "caption_splits" entry: an
            iterable of captions, each an iterable of word tokens.
        new_vocab_file: path the vocabulary JSON is written to (overwritten).

    Returns:
        The vocabulary dict that was written to ``new_vocab_file``.
    """
    special_tokens = {SOS: SOS_ID, EOS: EOS_ID, UNK: UNK_ID, PAD: PAD_ID}
    vocab_dict = dict(special_tokens)

    # Decode explicitly as UTF-8 so non-ASCII captions do not depend on the
    # platform's default encoding.
    with io.open(json_filename, encoding="utf-8") as f:
        temp = json.load(f)

    word_counts = collections.Counter()
    for item in temp["content"]:
        for caption_split in item["caption_splits"]:
            word_counts.update(word.lower() for word in caption_split)

    # A caption token that lower-cases to a special token (e.g. "<UNK>")
    # must not overwrite that token's reserved id.
    for token in special_tokens:
        word_counts.pop(token, None)

    # Most frequent first; ties broken alphabetically so ids are stable
    # across runs (plain dict order is arbitrary).
    word_items = sorted(word_counts.items(), key=lambda kv: (-kv[1], kv[0]))
    print(len(word_items))

    # Ordinary words are numbered right after the highest reserved id.
    first_id = max(special_tokens.values()) + 1
    for i, (w, n) in enumerate(word_items):
        print("%d %s %d" % (i, w, n))
        vocab_dict[w] = first_id + i

    with open(new_vocab_file, "w") as f:
        json.dump(vocab_dict, f, indent=4)

    print("Finished.")
    return vocab_dict


if __name__ == '__main__':
    # Both paths are relative to the current working directory.
    train_json = "../data/train.json"
    vocab_json = "../data/vocab.json"
    create_vocab_file(train_json, vocab_json)
